{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "61b87abe-0ee0-4b93-abaa-eb2087b9ab97",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn.functional import gelu, cosine_similarity\n",
    "from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "487f0884-ba1f-4841-bc11-a4bd512cd51a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class InBedder():\n",
    "    \n",
    "    def __init__(self, path='KomeijiForce/inbedder-roberta-large', device='cuda:0'):\n",
    "        \n",
    "        model = AutoModelForMaskedLM.from_pretrained(path)\n",
    "    \n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(path)\n",
    "        self.model = model.roberta\n",
    "        self.dense = model.lm_head.dense\n",
    "        self.layer_norm = model.lm_head.layer_norm\n",
    "        \n",
    "        self.device = torch.device(device)\n",
    "        self.model = self.model.to(self.device)\n",
    "        self.dense = self.dense.to(self.device)\n",
    "        self.layer_norm = self.layer_norm.to(self.device)\n",
    "        \n",
    "        self.vocab = self.tokenizer.get_vocab()\n",
    "        self.vocab = {self.vocab[key]:key for key in self.vocab}\n",
    "        \n",
    "    def encode(self, input_texts, instruction, n_mask):\n",
    "        \n",
    "        if type(instruction) == str:\n",
    "            prompts = [instruction + self.tokenizer.mask_token*n_mask for input_text in input_texts]\n",
    "        elif type(instruction) == list:\n",
    "            prompts = [inst + self.tokenizer.mask_token*n_mask for inst in instruction]\n",
    "    \n",
    "        inputs = self.tokenizer(input_texts, prompts, padding=True, truncation=True, return_tensors='pt').to(self.device)\n",
    "\n",
    "        mask = inputs.input_ids.eq(self.tokenizer.mask_token_id)\n",
    "        \n",
    "        outputs = self.model(**inputs)\n",
    "\n",
    "        logits = outputs.last_hidden_state[mask]\n",
    "        \n",
    "        logits = self.layer_norm(gelu(self.dense(logits)))\n",
    "        \n",
    "        logits = logits.reshape(len(input_texts), n_mask, -1)\n",
    "        \n",
    "        logits = logits.mean(1)\n",
    "            \n",
    "        logits = (logits - logits.mean(1, keepdim=True)) / logits.std(1, keepdim=True)\n",
    "        \n",
    "        return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bd41d329-f60b-4644-9fb3-0806a2f39dd6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lepeng/miniconda3/envs/komeiji/lib/python3.8/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  return self.fget.__get__(instance, owner)()\n"
     ]
    }
   ],
   "source": [
    "# Instantiate the embedder on CPU.\n",
    "inbedder = InBedder(\n",
    "    path='KomeijiForce/inbedder-roberta-large',\n",
    "    device='cpu',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "325508e4-c761-4aa1-ab61-142c205fbd22",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.9374, 0.9917], grad_fn=<SumBackward1>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Embed three sentences under an instruction that asks about the animal.\n",
    "instruction = \"What is the animal mentioned here?\"\n",
    "texts = [\"I love cat!\", \"I love dog!\", \"I dislike cat!\"]\n",
    "embeddings = inbedder.encode(input_texts=texts, instruction=instruction, n_mask=3)\n",
    "\n",
    "# Cosine similarity of the first sentence against each of the other two.\n",
    "cosine_similarity(embeddings[0:1], embeddings[1:], dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "417c8221-10cf-4fd2-acac-4041834fdce6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.9859, 0.8537], grad_fn=<SumBackward1>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same sentences, now embedded under an instruction that asks about the emotion.\n",
    "instruction = \"What is emotion expressed here?\"\n",
    "texts = [\"I love cat!\", \"I love dog!\", \"I dislike cat!\"]\n",
    "embeddings = inbedder.encode(input_texts=texts, instruction=instruction, n_mask=3)\n",
    "\n",
    "# Cosine similarity of the first sentence against each of the other two.\n",
    "cosine_similarity(embeddings[0:1], embeddings[1:], dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36c493d3-19e7-42a6-a217-10cd9ca28318",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
